package com.diven.common.tree;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.diven.common.tree.common.Constant;
import com.diven.common.tree.model.BinPoints;
import com.diven.common.tree.model.ClassifierTree;
import com.diven.common.tree.model.ConvertTypeModel;
import com.diven.common.tree.model.FeatureBinPoints;
import com.diven.common.tree.model.FeatureSpeculate;
import com.diven.common.tree.model.MissingValueFillModel;
import com.diven.common.tree.model.MissingValueFillType;
import com.diven.common.tree.model.PretreatmentModel;
import com.diven.common.tree.model.RuleInfo;
import com.diven.common.tree.utils.ListUtil;
import com.diven.common.tree.utils.TreeUtil;
import com.diven.common.tree.weka.WekaREPTree;

import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NumericToNominal;
import weka.filters.unsupervised.attribute.ReplaceMissingWithUserConstant;

/**
 * Decision tree that is first built automatically from a Weka data set and can then be
 * inspected, grown, re-split and pruned node by node.
 *
 * <p>Nodes are addressed by the number assigned by {@code TreeUtil.indexTreeNo}; the tree is
 * re-indexed after every structural change. The class keeps mutable state (the current tree and
 * the statistics captured by the last build) and performs no synchronization of its own.
 */
public class InteractiveDecisionTree implements DecisionTree<ClassifierTree>{
	
	private static final long serialVersionUID = 1528877813868941856L;
	
	/** Default minimum number of instances per leaf. */
	private final static int defaultMinNum = 2;
	/** Default maximum tree depth (passed to WekaREPTree as is; presumably -1 means "no limit" - TODO confirm). */
	private final static int defaultMaxDepth = -1;
	/** Default maximum depth of the single-feature trees used for manual growth. */
	private final static int defaultGrowthMaxDepth = 3;
	/** Default number of worker threads used for manual growth. */
	private final static int defaultGrowthThreadNum = 4;
	
	
	/** The current tree. */
	private ClassifierTree tree;
	/** Number of instances in the data set used by the last build. */
	private Integer instanceCount = 0;
	/** Index of the class label "1" (the overdue label) in the class attribute, or -1 when absent. */
	private int  overdueLabelIndex = -1;
	/** Root-node probability of the class label "1" (overall overdue rate), or -1 when absent. */
	private double  countOverdueRate = -1;
	
	public InteractiveDecisionTree() {}

	/** @return the current tree, or {@code null} when nothing has been built yet */
	public ClassifierTree getTree() {
		return tree;
	}

	/** @return the number of instances used by the last build */
	public Integer getInstanceCount() {
		return instanceCount;
	}

	/** @return the root probability of class label "1" captured by the last build, or -1 */
	public double getCountOverdueRate() {
		return countOverdueRate;
	}

	/** @return the index of class label "1" captured by the last build, or -1 */
	public int getOverdueLabelIndex() {
		return overdueLabelIndex;
	}
	
	/** @return the current tree, or {@code null} when nothing has been built yet */
	@Override
	public ClassifierTree getClassifierTree() {
		return tree;
	}

	/**
	 * Pre-processes a data set in three steps: feature filtering, numeric-to-nominal conversion
	 * and missing-value filling.
	 *
	 * <p>Feature filtering deletes attributes from the passed-in {@code instances} in place; the
	 * other two steps go through Weka filters and produce new data sets. The fill models that are
	 * applied get their feature type set as a side effect.
	 *
	 * @param instances the data set to process
	 * @param pretreatmentModel the processing rules; {@code null} is treated as an empty model
	 * @return the processed data set (may be the same object as {@code instances})
	 * @throws Exception if a Weka filter fails
	 */
	@Override
	public Instances pretreatment(Instances instances, PretreatmentModel pretreatmentModel) throws Exception{
		// No rules supplied: fall back to an empty model.
		if(pretreatmentModel == null) {
			pretreatmentModel = new PretreatmentModel();
		}
		// Feature filtering.
		this.retainFeatures(instances, pretreatmentModel.getFilterAttributes());
		// Type conversion.
		instances = this.convertNumericToNominal(instances, pretreatmentModel.getConvertTypeModels());
		// Missing-value handling.
		instances = this.fillMissingValues(instances, pretreatmentModel.getMissingValueFillModels());
		return instances;
	}
	
	/** Deletes, in place, every non-class attribute whose name is not listed; a null or empty list keeps everything. */
	private void retainFeatures(Instances instances, List<String> filterAttributes) {
		if(filterAttributes == null || filterAttributes.isEmpty()) {
			return;
		}
		// Collect names first: deleting while scanning would shift the attribute indices.
		List<String> deleteAttributes = new ArrayList<>();
		for(int index = 0; index < instances.numAttributes(); index++) {
			String name = instances.attribute(index).name();
			if(instances.classIndex() != index && !filterAttributes.contains(name)) {
				deleteAttributes.add(name);
			}
		}
		for(String name : deleteAttributes) {
			instances.deleteAttributeAt(instances.attribute(name).index());
		}
	}
	
	/** Converts the numeric attributes that the models declare as nominal; only numeric-to-nominal is handled. */
	private Instances convertNumericToNominal(Instances instances, List<ConvertTypeModel> convertTypeModels) throws Exception {
		if(convertTypeModels == null || convertTypeModels.isEmpty()) {
			return instances;
		}
		List<Integer> needConvertIndexes = new ArrayList<>();
		for(ConvertTypeModel convertTypeModel : convertTypeModels) {
			if(!Constant.NOMINAL.equals(convertTypeModel.getFeatureType())) {
				continue;
			}
			Attribute attribute = instances.attribute(convertTypeModel.getFeatureName());
			if(attribute != null && attribute.isNumeric()) {
				// The -R option of the filter takes 1-based attribute indices.
				needConvertIndexes.add(attribute.index() + 1);
			}
		}
		if(needConvertIndexes.isEmpty()) {
			return instances;
		}
		NumericToNominal convert = new NumericToNominal();
		convert.setOptions(new String[] {"-R", ListUtil.join(needConvertIndexes, ",")});
		convert.setInputFormat(instances);
		return Filter.useFilter(instances, convert);
	}
	
	/**
	 * Fills missing values. A {@code null} model list disables the step entirely; a non-null list
	 * (even an empty one) additionally fills every nominal attribute that has missing values but
	 * no applicable model with {@code Constant.NOMINAL_NULL}.
	 */
	private Instances fillMissingValues(Instances instances, List<MissingValueFillModel> missingValueFillModels) throws Exception {
		if(missingValueFillModels == null) {
			return instances;
		}
		// Find the attributes that actually contain missing values.
		List<String> hasMissingAttributes = new ArrayList<>();
		for(int i = 0; i < instances.numAttributes(); i++) {
			AttributeStats attributeStats = instances.attributeStats(i);
			if(attributeStats.missingCount > 0) {
				hasMissingAttributes.add(instances.attribute(i).name());
			}
		}
		// Keep the models whose fill type is applicable to the attribute's actual type.
		List<String> needFillModelAttributes = new ArrayList<>();
		List<MissingValueFillModel> needFillModels = new ArrayList<>();
		for(MissingValueFillModel model : missingValueFillModels) {
			if(!hasMissingAttributes.contains(model.getFeatureName())) {
				continue;
			}
			Attribute attribute = instances.attribute(model.getFeatureName());
			boolean nominalFill = attribute.isNominal()
					&& (model.getFillType() == MissingValueFillType.MISSING_VALLUE
						|| model.getFillType() == MissingValueFillType.MODE_VALUE);
			boolean numericFill = !nominalFill
					&& attribute.isNumeric()
					&& (model.getFillType() == MissingValueFillType.MISSING_VALLUE
						|| model.getFillType() == MissingValueFillType.MEAN_VALUE);
			if(nominalFill || numericFill) {
				model.setFeatureType(nominalFill ? Constant.NOMINAL : Constant.NUMERIC);
				needFillModels.add(model);
				needFillModelAttributes.add(model.getFeatureName());
			}
		}
		// Every nominal attribute with missing values and no applicable model gets the default fill.
		hasMissingAttributes.removeAll(needFillModelAttributes);
		for(String featureName : hasMissingAttributes) {
			if(instances.attribute(featureName).isNominal()) {
				MissingValueFillModel model = new MissingValueFillModel();
				model.setFeatureName(featureName);
				model.setFeatureType(Constant.NOMINAL);
				model.setFillType(MissingValueFillType.MISSING_VALLUE);
				needFillModels.add(model);
			}
		}
		
		// Nominal + MISSING_VALLUE: one filter run for all such attributes.
		List<String> nominalConstantFeatures = new ArrayList<>();
		for(MissingValueFillModel item : needFillModels) {
			if(Constant.NOMINAL.equals(item.getFeatureType()) && MissingValueFillType.MISSING_VALLUE == item.getFillType()) {
				nominalConstantFeatures.add(item.getFeatureName());
			}
		}
		if(!nominalConstantFeatures.isEmpty()) {
			instances = replaceNominalMissing(instances, ListUtil.join(nominalConstantFeatures, ","), Constant.NOMINAL_NULL);
		}
		// Nominal + MODE_VALUE: fill with the most frequent label.
		for(MissingValueFillModel item : needFillModels) {
			if(Constant.NOMINAL.equals(item.getFeatureType()) && MissingValueFillType.MODE_VALUE == item.getFillType()) {
				Attribute attribute = instances.attribute(item.getFeatureName());
				// For a nominal attribute meanOrMode yields the index of the mode label.
				int modeIndex = (int) instances.meanOrMode(attribute);
				instances = replaceNominalMissing(instances, item.getFeatureName(), attribute.value(modeIndex));
			}
		}
		// Numeric + MISSING_VALLUE: nothing to do, missing values are kept as they are.
		// Numeric + MEAN_VALUE: fill with the mean, rounded to 4 decimals.
		for(MissingValueFillModel item : needFillModels) {
			if(Constant.NUMERIC.equals(item.getFeatureType()) && MissingValueFillType.MEAN_VALUE == item.getFillType()) {
				double meanValue = instances.meanOrMode(instances.attribute(item.getFeatureName()));
				ReplaceMissingWithUserConstant filter = new ReplaceMissingWithUserConstant();
				filter.setNumericReplacementValue(Utils.doubleToString(meanValue, 4));
				filter.setAttributes(item.getFeatureName());
				filter.setInputFormat(instances);
				instances = Filter.useFilter(instances, filter);
			}
		}
		return instances;
	}
	
	/** Replaces the missing values of the given nominal attributes (comma separated) with one constant label. */
	private static Instances replaceNominalMissing(Instances instances, String attributes, String fillValue) throws Exception {
		ReplaceMissingWithUserConstant filter = new ReplaceMissingWithUserConstant();
		filter.setNominalStringReplacementValue(fillValue);
		filter.setAttributes(attributes);
		filter.setInputFormat(instances);
		return Filter.useFilter(instances, filter);
	}
	
	/**
	 * Builds the tree with the default maximum depth and minimum leaf size.
	 *
	 * @param instances the training data
	 * @throws Exception if the underlying tree cannot be built
	 */
	@Override
	public void buildClassifier(Instances instances) throws Exception {
		this.buildClassifier(instances, defaultMaxDepth, defaultMinNum);
	}

	/**
	 * Builds an unpruned tree and records the statistics of the training data.
	 *
	 * <p>When the class attribute is nominal and has a label "1", its index and its probability at
	 * the root are stored; otherwise both are reset to -1 so that values from a previous build do
	 * not leak into this one.
	 *
	 * @param instances the training data
	 * @param maxDepth maximum depth handed to the underlying tree
	 * @param minNum minimum number of instances per leaf handed to the underlying tree
	 * @throws Exception if the underlying tree cannot be built
	 */
	@Override
	public void buildClassifier(Instances instances, int maxDepth, double minNum) throws Exception {
		// Configure and build the underlying tree.
		WekaREPTree wekaREPTree = new WekaREPTree();
		wekaREPTree.setMinNum(minNum);
		wekaREPTree.setMaxDepth(maxDepth);
		wekaREPTree.setNoPruning(true);
		wekaREPTree.buildClassifier(instances, true);
		// Convert to the project's tree structure and number the nodes.
		this.tree = TreeUtil.toClassifierTree(wekaREPTree.getTree());
		TreeUtil.indexTreeNo(this.tree);
		// Total number of instances.
		this.instanceCount = instances.size();
		// Forget the label statistics of any earlier build before looking for label "1".
		this.overdueLabelIndex = -1;
		this.countOverdueRate = -1;
		Attribute classAttribute = instances.classAttribute();
		if(classAttribute.isNominal()) {
			for(int i = 0; i < classAttribute.numValues(); i++) {
				if("1".equals(classAttribute.value(i))) {
					this.overdueLabelIndex = i;
					this.countOverdueRate = this.tree.getClassProbs()[i];
					break;
				}
			}
		}
	}
	
	/**
	 * Returns the split candidates of a node using the default thread count and depth.
	 *
	 * @param nodeNo number of the node
	 * @return the candidates, sorted by descending gain
	 * @throws Exception if preparing the candidates fails or the calling thread is interrupted
	 */
	@Override
	public List<FeatureSpeculate> getSpeculateListByNodeId(Integer nodeNo) throws Exception{
		return this.getSpeculateListByNodeId(nodeNo, defaultGrowthThreadNum, defaultGrowthMaxDepth);
	}

	/**
	 * Returns, for every non-class feature of a node, the bin points and gain obtained from a
	 * single-feature tree, sorted by descending gain.
	 *
	 * <p>A non-empty result already stored on the node is returned as is. Otherwise one task per
	 * feature is run on a fixed thread pool; a feature whose task fails is left out of the result
	 * (the stack trace is printed). The method waits at most 10 minutes: on timeout the pool is
	 * stopped and the candidates finished so far are returned without being stored on the node,
	 * so that a later call computes the full list again.
	 *
	 * @param nodeNo number of the node
	 * @param threadNum size of the thread pool; values {@code <= 0} fall back to 2
	 * @param maxDepth maximum depth of the single-feature trees; values {@code <= 0} fall back to the default
	 * @return the candidates, sorted by descending gain
	 * @throws Exception if preparing the per-feature data fails or the calling thread is interrupted
	 */
	@Override
	public List<FeatureSpeculate> getSpeculateListByNodeId(Integer nodeNo, int threadNum, int maxDepth) throws Exception {
		// Look up the node and reuse its stored candidates when there are any.
		ClassifierTree node = TreeUtil.getNodeByTreeNo(this.tree, nodeNo);
		if(node.getSpeculates() != null && !node.getSpeculates().isEmpty()) {
			return node.getSpeculates();
		}
		final Instances nodeInstances = node.getInstances();
		final String classAttributeName = nodeInstances.classAttribute().name();
		final int growthMaxDepth = maxDepth > 0 ? maxDepth : defaultGrowthMaxDepth;
		// Workers append concurrently, hence the synchronized wrapper.
		final List<FeatureSpeculate> speculates = Collections.synchronizedList(new ArrayList<FeatureSpeculate>());
		// One task per attribute except the class attribute.
		final CountDownLatch count = new CountDownLatch(nodeInstances.numAttributes() - 1);
		ExecutorService threadPool = Executors.newFixedThreadPool(threadNum > 0 ? threadNum : 2);
		boolean finished = false;
		try {
			for(int i = 0; i < nodeInstances.numAttributes(); i++) {
				if(i == nodeInstances.classIndex()) {
					continue;
				}
				final Attribute feature = nodeInstances.attribute(i);
				ArrayList<Attribute> attributes = TreeUtil.getAttributes(nodeInstances, feature.name(), classAttributeName);
				// Two data sets holding only this feature and the class: complete rows and rows with a missing value.
				final Instances normalInstances = new Instances("Instances", attributes, nodeInstances.size() / 2 + 1);
				normalInstances.setClass(normalInstances.attribute(classAttributeName));
				final Instances missingInstances = new Instances("Instances", attributes, nodeInstances.size() / 2 + 1);
				missingInstances.setClass(missingInstances.attribute(classAttributeName));
				splitByMissing(nodeInstances, attributes, normalInstances, missingInstances);
				// Binning of one feature runs on the pool.
				threadPool.execute(new Runnable() {
					@Override
					public void run() {
						try {
							speculates.add(speculateFeature(feature, normalInstances, missingInstances, growthMaxDepth));
						}
						catch (Exception e) {
							// Best effort: a failing feature is only reported and left out of the result.
							e.printStackTrace();
						}
						finally {
							count.countDown();
						}
					}
				});
			}
			// Wait for all tasks; false means the timeout elapsed first.
			finished = count.await(10, TimeUnit.MINUTES);
		}
		finally {
			// Also reached when preparing a data set throws or the wait is interrupted.
			if(finished) {
				threadPool.shutdown();
			}
			else {
				threadPool.shutdownNow();
			}
		}
		// Work on a private copy: after a timeout workers may still be appending to the shared list.
		List<FeatureSpeculate> result;
		synchronized (speculates) {
			result = new ArrayList<>(speculates);
		}
		// Highest gain first.
		Collections.sort(result, new Comparator<FeatureSpeculate>() {
			@Override
			public int compare(FeatureSpeculate o1, FeatureSpeculate o2) {
				return o2.getGain().compareTo(o1.getGain());
			}
		});
		// Only a complete run is stored on the node.
		if(finished) {
			node.setSpeculates(result);
		}
		return result;
	}
	
	/** Copies every row of {@code source}, reduced to {@code attributes}, into the complete or the has-missing data set. */
	private static void splitByMissing(Instances source, List<Attribute> attributes, Instances normalInstances, Instances missingInstances) {
		// The matching source attributes do not change per row, so resolve them once.
		Attribute[] rawAttributes = new Attribute[attributes.size()];
		for(int k = 0; k < rawAttributes.length; k++) {
			rawAttributes[k] = source.attribute(attributes.get(k).name());
		}
		for(int j = 0; j < source.size(); j++) {
			Instance item = source.get(j);
			Instance inst = new DenseInstance(attributes.size());
			boolean hasMissingValue = false;
			for(int k = 0; k < rawAttributes.length; k++) {
				Attribute attribute = attributes.get(k);
				Attribute rawAttr = rawAttributes[k];
				if(item.isMissing(rawAttr)) {
					inst.setValue(attribute, Double.NaN);
					hasMissingValue = true;
				}
				else if(rawAttr.isNumeric()) {
					inst.setValue(attribute, item.value(rawAttr));
				}
				else {
					inst.setValue(attribute, item.stringValue(rawAttr));
				}
			}
			if(hasMissingValue) {
				missingInstances.add(inst);
			}
			else {
				normalInstances.add(inst);
			}
		}
	}
	
	/** Builds the single-feature tree(s) for one feature and derives its bin points and gain. */
	private static FeatureSpeculate speculateFeature(Attribute feature, Instances normalInstances, Instances missingInstances, int maxDepth) throws Exception {
		// Tree over the complete rows.
		WekaREPTree normalWekaREPTree = new WekaREPTree();
		normalWekaREPTree.setNoPruning(true);
		normalWekaREPTree.setMaxDepth(maxDepth);
		normalWekaREPTree.buildClassifier(normalInstances, false);
		ClassifierTree normalClassifierTree = TreeUtil.toClassifierTree(normalWekaREPTree.getTree());
		TreeUtil.indexTreeNo(normalClassifierTree);
		// Rows with missing values form a bin of their own (depth-0 tree).
		ClassifierTree missingClassifierTree = null;
		if(missingInstances.size() > 0) {
			WekaREPTree missingWekaREPTree = new WekaREPTree();
			missingWekaREPTree.setNoPruning(true);
			missingWekaREPTree.setMaxDepth(0);
			missingWekaREPTree.buildClassifier(missingInstances, false);
			missingClassifierTree = TreeUtil.toClassifierTree(missingWekaREPTree.getTree());
			TreeUtil.indexTreeNo(missingClassifierTree);
		}
		// Bin points, leaves and gain.
		BinPoints binPoints = TreeUtil.getBinPointBySingleFeatureTree(feature, normalClassifierTree, missingClassifierTree);
		List<ClassifierTree> leafNodes = TreeUtil.getLeafNodes(normalClassifierTree, missingClassifierTree);
		double gain = TreeUtil.gainBySingleFeatureTreeLeaf(feature, leafNodes);
		// Assemble the result.
		FeatureSpeculate speculate = new FeatureSpeculate();
		speculate.setGain(gain);
		speculate.setBinPoints(binPoints);
		speculate.setFeatureName(feature.name());
		speculate.setFeatureType(feature.isNominal() ? Constant.NOMINAL : Constant.NUMERIC);
		return speculate;
	}

	/**
	 * Grows a node on the given feature using the default growth depth.
	 *
	 * @param nodeNo number of the node
	 * @param feature name of the feature to split on
	 * @return {@code true} if the node was split
	 * @throws Exception if building the child nodes fails
	 */
	@Override
	public boolean manualGrowthByNodeId(Integer nodeNo, String feature)throws Exception {
		return manualGrowthByNodeId(nodeNo, feature, defaultGrowthMaxDepth);
	}
	
	/**
	 * Grows a node on the given feature, using the bin points of the split candidates already
	 * stored on the node (see {@link #getSpeculateListByNodeId(Integer)}).
	 *
	 * @param nodeNo number of the node
	 * @param feature name of the feature to split on; {@code null} yields {@code false}
	 * @param maxDepth not used by this implementation
	 * @return {@code true} if the node was split; {@code false} if the node has no stored
	 *         candidates, or none with non-empty bin points for the feature
	 * @throws Exception if building the child nodes fails
	 */
	@Override
	public boolean manualGrowthByNodeId(Integer nodeNo, String feature, int maxDepth) throws Exception{
		ClassifierTree thisNode = TreeUtil.getNodeByTreeNo(this.tree, nodeNo);
		List<FeatureSpeculate> featureSpeculates = thisNode.getSpeculates();
		if(feature == null || featureSpeculates == null || featureSpeculates.isEmpty()) {
			return false;
		}
		// Bin points of the first candidate for this feature.
		BinPoints binPoints = null;
		for(FeatureSpeculate featureSpeculate : featureSpeculates) {
			if(feature.equals(featureSpeculate.getFeatureName())) {
				binPoints = featureSpeculate.getBinPoints();
				break;
			}
		}
		if(binPoints == null || binPoints.isEmpty()) {
			return false;
		}
		return manualEditByNodeId(nodeNo, feature, binPoints);
	}

	/**
	 * Returns the split feature and bin points of a node.
	 *
	 * @param nodeNo number of the node
	 * @return the feature name, type and bin points, or {@code null} when the node has no split attribute
	 */
	@Override
	public FeatureBinPoints getBinPointsByNodeId(Integer nodeNo) {
		ClassifierTree thisNode = TreeUtil.getNodeByTreeNo(this.tree, nodeNo);
		if(thisNode.getSplitAttribute() < 0) {
			return null;
		}
		Attribute feature = thisNode.getInstances().attribute(thisNode.getSplitAttribute());
		FeatureBinPoints binPoints = new FeatureBinPoints();
		binPoints.setFeatureName(feature.name());
		binPoints.setFeatureType(feature.isNominal() ? Constant.NOMINAL : Constant.NUMERIC);
		binPoints.setBinPoints(thisNode.getSplitPoints());
		return binPoints;
	}

	/**
	 * Replaces the children of a node by splitting its instances with the given bin points.
	 *
	 * <p>The split feature is looked up by name; when that yields nothing the node's current split
	 * attribute, if any, is used. Every bin becomes a depth-0 child; an empty bin becomes an empty
	 * node. The node itself is modified only after all children have been created.
	 *
	 * @param nodeNo number of the node
	 * @param feature name of the feature to split on; may be {@code null} or empty to keep the current one
	 * @param binPoints the bin points: one bin per point for a nominal feature, one more than the
	 *        number of points for a numeric feature
	 * @return {@code true} if the node was split, {@code false} if no split feature could be
	 *         determined or {@code binPoints} is {@code null}
	 * @throws Exception if building a child node fails
	 */
	@Override
	public boolean manualEditByNodeId(Integer nodeNo, String feature, BinPoints binPoints) throws Exception {
		ClassifierTree thisNode = TreeUtil.getNodeByTreeNo(this.tree, nodeNo);
		Instances nodeInstances = thisNode.getInstances();
		// Resolve the split feature: by name first, then the node's current split attribute.
		Attribute attribute = null;
		if(feature != null && !feature.isEmpty()) {
			attribute = nodeInstances.attribute(feature);
		}
		if(attribute == null && thisNode.getSplitAttribute() >= 0) {
			attribute = nodeInstances.attribute(thisNode.getSplitAttribute());
		}
		// Nothing to split on, or nothing to split with.
		if(attribute == null || binPoints == null) {
			return false;
		}
		// A numeric feature has one more bin than it has split points.
		int binCount = attribute.isNominal() ? binPoints.size() : binPoints.size() + 1;
		Instances[] childrenInstances = new Instances[binCount];
		for(int i = 0; i < childrenInstances.length; i++) {
			// Empty data set with the same header as the node's data.
			childrenInstances[i] = new Instances(nodeInstances, 0);
		}
		// Distribute the instances over the bins.
		for(Instance instance : nodeInstances) {
			int binNum = TreeUtil.getBinNumForAttribute(attribute, instance, binPoints);
			childrenInstances[binNum].add(instance);
		}
		// Create one child node per bin.
		List<ClassifierTree> leafNodes = new ArrayList<>();
		for(Instances instances : childrenInstances) {
			ClassifierTree leafNode;
			if(instances.size() > 0) {
				WekaREPTree wekaREPTree = new WekaREPTree();
				wekaREPTree.setNoPruning(true);
				wekaREPTree.setMaxDepth(0);
				wekaREPTree.buildClassifier(instances, true);
				leafNode = TreeUtil.toClassifierTree(wekaREPTree.getTree());
			}
			else {
				leafNode = TreeUtil.getEmptyClassifierTree(instances);
			}
			leafNodes.add(leafNode);
		}
		// Attach the children and record the split on the node.
		thisNode.setChildNodes(leafNodes.toArray(new ClassifierTree[leafNodes.size()]));
		thisNode.setSplitPoints(binPoints);
		thisNode.setSplitAttribute(attribute.index());
		thisNode.setInstancesProps(TreeUtil.getInstancesPropsByNodeLeaf(nodeInstances.size(), leafNodes));
		// Node numbers change with the structure.
		TreeUtil.indexTreeNo(this.tree);
		return true;
	}
	
	/**
	 * Turns a node into a leaf by removing its children and split information, then re-indexes the tree.
	 *
	 * @param nodeNo number of the node
	 * @throws Exception declared by the interface
	 */
	@Override
	public void manualPruneByNodeId(Integer nodeNo) throws Exception {
		ClassifierTree thisNode = TreeUtil.getNodeByTreeNo(this.tree, nodeNo);
		thisNode.setChildNodes(null);
		thisNode.setSplitAttribute(-1);
		thisNode.setSplitPoints(null);
		thisNode.setInstancesProps(null);
		// Node numbers change with the structure.
		TreeUtil.indexTreeNo(this.tree);
	}
	
	/** @return the rule information of the leaf nodes, as computed by {@code TreeUtil.getRuleInfoForleafNodes} */
	@Override
	public List<RuleInfo> getRuleInfoForleafNodes() {
		return TreeUtil.getRuleInfoForleafNodes(this.instanceCount, this.overdueLabelIndex, this.countOverdueRate, this.tree);
	}
	
	/**
	 * Renders the tree as a Graphviz "digraph" document.
	 *
	 * @return the DOT source
	 * @throws Exception if no tree has been built yet
	 */
	@Override
	public String toDotGraph() throws Exception {
		if (this.tree == null) {
			throw new Exception("InteractiveDecisionTree: No model built yet.");
		}
		StringBuilder graph = new StringBuilder();
		graph.append("digraph Tree {\n");
		graph.append("edge [style=bold]\n");
		graph.append("node [style=\"filled\" fillcolor=\"#ffffff\" fontname=\"Monospace\" width=3.5 height=1 fixedsize=false]\n");
		graph.append(this.tree.toGraph(this.instanceCount));
		graph.append("\n}\n");
		return graph.toString();
	}
	
}
